package cn.plugins.generator.utils

import cn.plugins.generator.entity.ColumnEntity
import cn.plugins.generator.entity.DatabaseInfoEntity
import cn.plugins.generator.entity.TableEntity
import java.sql.*


/**
 * PostgreSQL database connection utility used by the generator to read table and column metadata.
 *
 * Created 2020-01-08 15:48
 *
 * @author jiangyun
 */
object PostgresqlDatabaseUtil {

    /**
     * Looks up a table's own comment (`objsubid = 0`) in `pg_description`.
     *
     * The relation name is bound as a JDBC parameter instead of being spliced into the text, so a
     * table name can never inject SQL. The LEFT JOIN keeps tables that have no comment; their
     * comment is reported as an empty string.
     */
    private val TABLE_INFO_SQL = """
        select c.relname as table_name, coalesce(d.description, '') as comment
        from pg_catalog.pg_class c
        left join pg_catalog.pg_description d on c.oid = d.objoid and d.objsubid = 0
        where c.relname = ?
    """.trimIndent()

    /**
     * Lists the user-visible columns (`attnum > 0`, not dropped) of the relation bound to the single
     * `?` parameter, ordered by column position.
     *
     * - The LEFT JOIN on `pg_description` keeps columns without a comment; their comment is `''`.
     * - `columnKey` is `'PRI'` when the column belongs to a primary-key constraint of a relation
     *   with the same `relname`, otherwise `''`.
     * - NOTE(review): `isNullable` is `'true'` when `pg_attribute.attnotnull` is true, i.e. when the
     *   column is declared NOT NULL. That is the opposite of what the alias suggests. It is kept
     *   unchanged because generator templates may rely on it. Confirm before flipping.
     * - NOTE(review): for variable-length types `characterMaximumLength` is `atttypmod - 4`. That
     *   yields `-5` when `atttypmod` is `-1` (no declared length). Confirm how callers treat it.
     * - NOTE(review): relations are matched by `relname` only, with no schema filter. Same-named
     *   tables in different schemas would all contribute rows.
     */
    private val COLUMNS_SQL = """
        select
            a.attname as columnName,
            concat_ws('', t.typname) as dataType,
            (case when a.attnotnull = true then 'true' else 'false' end) as isNullable,
            (case when a.attlen > 0 then a.attlen else a.atttypmod - 4 end) as characterMaximumLength,
            coalesce(d.description, '') as columnComment,
            (case
                when (
                    select count(pg_constraint.*)
                    from pg_constraint
                    inner join pg_class on pg_constraint.conrelid = pg_class.oid
                    inner join pg_attribute on pg_attribute.attrelid = pg_class.oid
                        and pg_attribute.attnum = any(pg_constraint.conkey)
                    inner join pg_type on pg_type.oid = pg_attribute.atttypid
                    where pg_class.relname = c.relname
                        and pg_constraint.contype = 'p'
                        and pg_attribute.attname = a.attname
                ) > 0 then 'PRI'
                else ''
            end) as columnKey
        from pg_catalog.pg_class c
        inner join pg_catalog.pg_attribute a on a.attrelid = c.oid
        inner join pg_catalog.pg_type t on a.atttypid = t.oid
        left join pg_catalog.pg_description d on d.objoid = a.attrelid and d.objsubid = a.attnum
        where c.relname = ?
            and a.attnum > 0
            and not a.attisdropped
        order by c.relname desc, a.attnum asc
    """.trimIndent()

    /**
     * Opens a new JDBC connection to the PostgreSQL database described by [databaseInfo].
     *
     * Every call opens a fresh connection. The caller owns it and is responsible for closing it.
     *
     * @param databaseInfo connection URL and credentials. If it is `null`, `null` URL/credentials
     *        are handed to [DriverManager], which rejects them with an [SQLException].
     * @return the open connection. The nullable return type is kept for source compatibility;
     *         this implementation never returns `null`.
     * @throws SQLException if the connection cannot be established
     * @throws ClassNotFoundException if `org.postgresql.Driver` is not on the classpath
     */
    @Throws(Exception::class)
    fun getConnection(databaseInfo: DatabaseInfoEntity?): Connection? = openConnection(databaseInfo)

    /**
     * Reads the comment of table [tableName].
     *
     * Uses a dedicated connection. The connection, statement and result set are all closed before
     * returning, including when an exception is thrown.
     *
     * @param url       JDBC connection URL
     * @param userName  database user name
     * @param password  database password
     * @param tableName table name (`pg_class.relname`), bound as a query parameter
     * @return a [TableEntity]. When a relation named [tableName] exists, `tableName` and `comment` are
     *         populated (comment is `""` if the table has none). Otherwise the entity is returned
     *         exactly as constructed. The nullable return type is kept for source compatibility;
     *         this implementation never returns `null`.
     * @throws SQLException if connecting or querying fails
     * @throws ClassNotFoundException if the PostgreSQL JDBC driver is not on the classpath
     */
    @Throws(Exception::class)
    fun getTableInfo(url: String, userName: String, password: String, tableName: String): TableEntity? =
        openConnection(DatabaseInfoEntity(url, userName, password)).use { connection ->
            connection.prepareStatement(TABLE_INFO_SQL).use { statement ->
                statement.setString(1, tableName)
                statement.executeQuery().use { resultSet ->
                    val tableInfo = TableEntity()
                    if (resultSet.next()) {
                        tableInfo.tableName = stripTableSuffix(tableName)
                        tableInfo.comment = resultSet.getString("comment")
                    }
                    tableInfo
                }
            }
        }

    /**
     * Lists the columns of table [tableName] in column order.
     *
     * Uses a dedicated connection. The connection, statement and result set are all closed before
     * returning, including when an exception is thrown.
     *
     * @param url       JDBC connection URL
     * @param userName  database user name
     * @param password  database password
     * @param tableName table name (`pg_class.relname`), bound as a query parameter
     * @return one [ColumnEntity] per non-dropped user column (empty if the table does not exist).
     *         The nullable return type is kept for source compatibility; this implementation never
     *         returns `null`.
     * @throws SQLException if connecting or querying fails
     * @throws ClassNotFoundException if the PostgreSQL JDBC driver is not on the classpath
     */
    @Throws(Exception::class)
    fun findTableColumns(
        url: String,
        userName: String,
        password: String,
        tableName: String
    ): MutableList<ColumnEntity>? =
        openConnection(DatabaseInfoEntity(url, userName, password)).use { connection ->
            connection.prepareStatement(COLUMNS_SQL).use { statement ->
                statement.setString(1, tableName)
                statement.executeQuery().use { resultSet ->
                    val columns = mutableListOf<ColumnEntity>()
                    while (resultSet.next()) {
                        val column = ColumnEntity()
                        column.columnName = resultSet.getString("columnName")
                        column.dataType = resultSet.getString("dataType")
                        // TODO: columnDefault is not populated yet. The default lives encoded in
                        //  pg_attrdef.adbin; pg_get_expr(adbin, adrelid) is presumably the way to decode it.
                        //column.columnDefault = resultSet.getString("columnDefault")
                        column.isNullable = resultSet.getString("isNullable")
                        column.characterMaximumLength = resultSet.getString("characterMaximumLength")
                        column.columnComment = resultSet.getString("columnComment")
                        column.columnKey = resultSet.getString("columnKey")
                        columns.add(column)
                    }
                    columns
                }
            }
        }

    /**
     * Loads the PostgreSQL driver and opens a new connection.
     *
     * Failures are printed to stdout and then rethrown unchanged.
     */
    private fun openConnection(databaseInfo: DatabaseInfoEntity?): Connection =
        try {
            // Explicit driver loading: the driver may not be picked up automatically from this classloader.
            Class.forName("org.postgresql.Driver")
            DriverManager.getConnection(
                databaseInfo?.url,
                databaseInfo?.userName,
                databaseInfo?.password
            )
        } catch (e: SQLException) {
            e.printStackTrace()
            println("数据库连接失败！")
            throw e
        } catch (e: ClassNotFoundException) {
            e.printStackTrace()
            println("获取数据库驱动失败！")
            throw e
        }

    /**
     * Returns [tableName] without its last two characters when its second-to-last character is `表`.
     * Otherwise it returns [tableName] unchanged. Names shorter than two characters are returned
     * unchanged.
     *
     * NOTE(review): the check looks at the second-to-last character but strips two characters.
     * Possibly the intent was to strip a trailing `表`. Behavior is kept as-is; confirm.
     */
    private fun stripTableSuffix(tableName: String): String =
        if (tableName.length >= 2 && tableName[tableName.length - 2] == '表') {
            tableName.dropLast(2)
        } else {
            tableName
        }
}